#!/usr/bin/env python3
"""
divergence_warm.py

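Subset divergence (items 1–40 vs 41–50) using the absolute difference of mean
word counts, |CPU - GPU|, computed per item from *Warm runs only* (default tags
W2/W3/W4). Only files carrying the selected tags are read, so cold-start (C1)
data cannot leak into the analysis unless a C1 tag is passed explicitly via
--warm_tags or --alt_tags. For each condition the two subsets are compared with
Welch's t-test and a two-sided Mann–Whitney U test.

Input files are expected at <base_dir>/<Condition>_<CPU|GPU>_T0_<tag>.txt.

Optional: run the same analysis for the Alt environment (A1/A2/A3) with --group Alt.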

Usage:
  python divergence_warm.py --base_dir "." --group Warm
  python divergence_warm.py --base_dir "." --group Alt
"""

import argparse
import re
import numpy as np
from pathlib import Path
from scipy import stats

# Marks the start of a numbered item: at the beginning of a line (re.M), optional
# whitespace (which may span blank lines), a 1–2 digit number, a period, and one
# whitespace character. The number is captured so re.split() keeps it in its
# output, interleaved with the item bodies.
# NOTE: any line starting "N. " is treated as an item boundary, including lines
# of a nested numbered list inside an item.
ITEM_SPLIT_RE = re.compile(r"(?m)^\s*(\d{1,2})\.\s")

# Prompt conditions analysed by main(); each name is also the file-name prefix
# that build_files() uses for the run files.
CONDITIONS = ["BaselineNoJSON", "BaselineWithJSON", "HighRigor", "PushbackStrong"]

def trim_item50_noise(text: str) -> str:
    """Drop a trailing "please note" paragraph and trailing whitespace.

    The cut point is the first blank line that is followed by "please note"
    (case-insensitive, as a whole word). Everything from there on is removed.
    Without such a paragraph, the text is only right-stripped.
    """
    match = re.search(r"(?i)\n\s*\n\s*please note\b", text)
    cut = match.start() if match else len(text)
    return text[:cut].rstrip()

def extract_items(text: str) -> dict[int, str]:
    """Split one run's text into its numbered items 1..50.

    Item boundaries are lines matching ``ITEM_SPLIT_RE`` (e.g. ``"12. "``).
    Anything before the first boundary is discarded. Each item's content is
    stripped of surrounding whitespace, and item 50 is further passed through
    ``trim_item50_noise``.

    Args:
        text: Full text of a single run file.

    Returns:
        Mapping from item number to its content, with keys exactly 1..50.

    Raises:
        ValueError: if an item number occurs more than once, or if the item
            numbers found are not exactly 1..50.
    """
    # With one capturing group, split() returns
    # [preamble, num_1, body_1, num_2, body_2, ...]; pair numbers with bodies.
    parts = ITEM_SPLIT_RE.split(text)
    items: dict[int, str] = {}
    repeated: set[int] = set()
    for num_text, body in zip(parts[1::2], parts[2::2]):
        num = int(num_text)
        if num in items:
            # Without this check, a repeated number (e.g. a nested "1." sub-list
            # inside another item) would silently overwrite the earlier item. It
            # would still pass the key-set check below, corrupting the word counts.
            repeated.add(num)
        items[num] = body.strip()

    if repeated:
        raise ValueError(f"Duplicate item numbers: {sorted(repeated)}")

    if 50 in items:
        items[50] = trim_item50_noise(items[50])

    expected = set(range(1, 51))
    if set(items.keys()) != expected:
        raise ValueError(f"Bad item keys. Missing: {sorted(expected - set(items))} Extra: {sorted(set(items) - expected)}")

    return items

def load_mean_wordcounts(filepaths: list[Path]) -> dict[int, float]:
    """Average the per-item word counts of several run files.

    Each file is parsed with ``extract_items``. An item's word count is the
    number of whitespace-separated tokens in its content. Counts are then
    averaged across files, item by item.

    Args:
        filepaths: Run files to read (decoded as UTF-8; undecodable bytes are
            dropped).

    Returns:
        Mapping from item number (1..50) to its mean word count across files.

    Raises:
        ValueError: if ``filepaths`` yields no files (otherwise every mean would
            silently be NaN), or if a file fails ``extract_items`` validation.
    """
    per_run_counts: list[dict[int, int]] = []
    for fp in filepaths:
        text = fp.read_text(encoding="utf-8", errors="ignore")
        items = extract_items(text)
        per_run_counts.append({num: len(body.split()) for num, body in items.items()})

    # Checked after the loop so that empty iterators/generators are caught too,
    # not just empty lists.
    if not per_run_counts:
        raise ValueError("load_mean_wordcounts() requires at least one file")

    return {i: float(np.mean([counts[i] for counts in per_run_counts])) for i in range(1, 51)}

def build_files(base: Path, cond: str, hw: str, tags: list[str]) -> list[Path]:
    """Return the expected run-file paths for one condition/hardware pair.

    There is one path per tag, named ``{cond}_{hw}_T0_{tag}.txt`` inside
    ``base``, in the same order as ``tags``. Existence is not checked here.
    """
    return [base.joinpath(f"{cond}_{hw}_T0_{t}.txt") for t in tags]

def analyze_condition(base: Path, cond: str, tags: list[str]) -> dict:
    """Compare CPU/GPU word-count divergence between items 1–40 and 41–50.

    For one condition, this loads the CPU and GPU run files for ``tags`` and
    averages the per-item word counts within each hardware type. It then takes
    the per-item absolute difference. The first 40 differences are tested
    against the last 10 with Welch's t-test (unequal variances) and a two-sided
    Mann–Whitney U test.

    Returns:
        Dict with the condition name, the mean divergence of each subset, and
        the statistic and p-value of both tests (numeric values as floats).

    Raises:
        FileNotFoundError: if any expected CPU or GPU file is absent.
    """
    paths_by_hw = {hw: build_files(base, cond, hw, tags) for hw in ("CPU", "GPU")}

    # Report every missing file at once rather than failing on the first.
    absent = [str(p) for p in [*paths_by_hw["CPU"], *paths_by_hw["GPU"]] if not p.exists()]
    if absent:
        raise FileNotFoundError("Missing files:\n  " + "\n  ".join(absent))

    cpu = load_mean_wordcounts(paths_by_hw["CPU"])
    gpu = load_mean_wordcounts(paths_by_hw["GPU"])

    item_ids = range(1, 51)
    diffs = np.abs(np.array([cpu[i] for i in item_ids]) - np.array([gpu[i] for i in item_ids]))
    first, last = diffs[:40], diffs[40:]  # items 1–40 and 41–50

    welch_t, welch_p = stats.ttest_ind(first, last, equal_var=False)
    mwu_u, mwu_p = stats.mannwhitneyu(first, last, alternative="two-sided")

    return {
        "condition": cond,
        "mean_div_1_40": float(first.mean()),
        "mean_div_41_50": float(last.mean()),
        "welch_t": float(welch_t),
        "welch_p": float(welch_p),
        "mannwhitney_u": float(mwu_u),
        "mannwhitney_p": float(mwu_p),
    }

def main():
    """Command-line entry point: print the subset-divergence report.

    Chooses the Warm or Alt tag list from ``--group``. For every condition in
    ``CONDITIONS``, it prints the mean divergence of items 1–40 and 41–50 plus
    the Welch and Mann–Whitney results. A missing run file raises
    FileNotFoundError, which stops the report at that condition.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str, required=True, help="Directory containing the run .txt files")
    parser.add_argument("--group", choices=["Warm", "Alt"], default="Warm",
                        help="Warm uses W2/W3/W4; Alt uses A1/A2/A3. Cold is intentionally unsupported.")
    parser.add_argument("--warm_tags", nargs="+", default=["W2", "W3", "W4"])
    parser.add_argument("--alt_tags", nargs="+", default=["A1", "A2", "A3"])
    args = parser.parse_args()

    # --group is restricted by `choices`, so exactly one of these keys applies.
    tags = {"Warm": args.warm_tags, "Alt": args.alt_tags}[args.group]
    base = Path(args.base_dir)

    print(f"\nSubset divergence analysis (group={args.group}, tags={tags})")
    print("Divergence metric: |CPU_mean_wordcount - GPU_mean_wordcount| per item\n")

    # The lazy generator keeps the interleaving: each condition is analysed and
    # printed before the next one's files are loaded.
    for r in (analyze_condition(base, cond, tags) for cond in CONDITIONS):
        print(f"=== {r['condition']} ===")
        print(f"Mean divergence (1–40):  {r['mean_div_1_40']:.3f}")
        print(f"Mean divergence (41–50): {r['mean_div_41_50']:.3f}")
        print(f"Welch t-test: t={r['welch_t']:.3f}, p={r['welch_p']:.6g}")
        print(f"Mann–Whitney U: U={r['mannwhitney_u']:.3f}, p={r['mannwhitney_p']:.6g}\n")

if __name__ == "__main__":
    main()